{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "45ca91c2",
   "metadata": {},
   "source": [
    "# AI tool to generate unit tests for the provided Java code\n",
    "\n",
    "Here we build a Gradio App that uses the frontier models to generate unit tests for a java code. For testing purposes I have used the *cheaper* versions of the models, not the ones the leaderboards indicate as the best ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f44901f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import google.generativeai as genai\n",
    "import anthropic\n",
    "import gradio as gr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c47706b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# environment\n",
    "\n",
    "load_dotenv(override=True)\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
    "google_api_key = os.getenv('GOOGLE_API_KEY')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35446b9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "openai = OpenAI()\n",
    "claude = anthropic.Anthropic()\n",
    "genai.configure()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e899efd",
   "metadata": {},
   "outputs": [],
   "source": [
    "OPENAI_MODEL = \"gpt-4o-mini\"\n",
    "CLAUDE_MODEL = \"claude-3-haiku-20240307\"\n",
    "GEMINI_MODEL = 'gemini-2.0-flash-lite'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47640f53",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_message = \"You are an assistant that generates unit test for java code. \"\n",
    "system_message += \"Generate one JUnit5 test class with all the relevant test cases in it.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f41ccbf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def user_prompt_for(code):\n",
    "    user_prompt = \"Generate unit tests for this java code.\\n\\n\"\n",
    "    user_prompt += code\n",
    "    return user_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c57c0000",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_code = \"\"\"\n",
    "package com.hma.kafkaproducertest.rest;\n",
    "\n",
    "import com.hma.kafkaproducertest.model.TestDTO;\n",
    "import com.hma.kafkaproducertest.producer.TestProducer;\n",
    "import org.springframework.web.bind.annotation.*;\n",
    "\n",
    "@RestController\n",
    "@RequestMapping(\"/api\")\n",
    "public class TestController {\n",
    "\n",
    "    private final TestProducer producer;\n",
    "\n",
    "    public TestController(TestProducer producer) {\n",
    "        this.producer = producer;\n",
    "    }\n",
    "\n",
    "    @PostMapping(\"/event\")\n",
    "    public TestDTO triggerKafkaEvent(@RequestBody TestDTO payload) {\n",
    "        producer.sendMessage(payload, \"test\");\n",
    "        return payload;\n",
    "    }\n",
    "\n",
    "}\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00c71128",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stream_gpt(code):\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": system_message},\n",
    "        {\"role\": \"user\", \"content\": user_prompt_for(code)}\n",
    "      ]\n",
    "    stream = openai.chat.completions.create(\n",
    "        model=OPENAI_MODEL,\n",
    "        messages=messages,\n",
    "        stream=True\n",
    "    )\n",
    "    result = \"\"\n",
    "    for chunk in stream:\n",
    "        result += chunk.choices[0].delta.content or \"\"\n",
    "        yield result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca92f8a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stream_claude(code):\n",
    "    result = claude.messages.stream(\n",
    "        model=CLAUDE_MODEL,\n",
    "        max_tokens=2000,\n",
    "        system=system_message,\n",
    "        messages=[\n",
    "            {\"role\": \"user\", \"content\": user_prompt_for(code)},\n",
    "        ],\n",
    "    )\n",
    "    response = \"\"\n",
    "    with result as stream:\n",
    "        for text in stream.text_stream:\n",
    "            response += text or \"\"\n",
    "            yield response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dffed4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stream_gemini(code):\n",
    "    gemini = genai.GenerativeModel(\n",
    "        model_name=GEMINI_MODEL,\n",
    "        system_instruction=system_message\n",
    "    )\n",
    "    stream = gemini.generate_content(user_prompt_for(code), stream=True)\n",
    "    result = \"\"\n",
    "    for chunk in stream:\n",
    "        result += chunk.text or \"\"\n",
    "        yield result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f9c267",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_tests(code, model):\n",
    "    if model==\"GPT\":\n",
    "        result = stream_gpt(code)\n",
    "    elif model==\"Claude\":\n",
    "        result = stream_claude(code)\n",
    "    elif model==\"Gemini\":\n",
    "        result = stream_gemini(code)\n",
    "    else:\n",
    "        raise ValueError(\"Unknown model\")\n",
    "    yield from result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c04c0a1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "with gr.Blocks() as ui:\n",
    "    with gr.Row():\n",
    "        original_code = gr.Textbox(label=\"Java code:\", lines=10, value=test_code)\n",
    "        generated_code = gr.Markdown(label=\"Unit tests:\")\n",
    "    with gr.Row():\n",
    "        model = gr.Dropdown([\"GPT\", \"Claude\", \"Gemini\"], label=\"Select model\", value=\"GPT\")\n",
    "        generate = gr.Button(\"Generate tests\")\n",
    "\n",
    "    generate.click(generate_tests, inputs=[original_code, model], outputs=[generated_code])\n",
    "\n",
    "ui.launch(inbrowser=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84d33a5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "ui.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbd50bf7",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "The models are missing some information as the `TestDTO` is not defined in the code provided as an input.\n",
    "\n",
    "Results:\n",
    "- Gemini: Generates a well constructed test class with multiple test cases covering scenarios with valid and invalid inputs. It makes assumptions about the content of `TestDTO` and adds a note about those as a comment.\n",
    "- Claude: Similar approach to unknown format of `TestDTO`, although no comment added about the assumptions made. The test cases are strutured differently, and they don't cover any case of invalid input, which in my opinion is an important test for a REST endpoint.\n",
    "- GPT: While the other two generated *real* unit tests using the mockito extension, GPT generated a *webMVC* test. The other two relied on the equality impelemntation of `TestDTO`, while GPT checks separately each field in the response. As this type of test spins up the application context, the test won't run without additional configuration. In addition, some imports are missing from the test file.\n",
    "\n",
    "It comes down to personal preferences, but I would give the point to Gemini for this one."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llms",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
